{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import random\n",
    "from collections import deque"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "GAMMA = 0.95 # discount factor\n",
    "LEARNING_RATE=0.01\n",
    "class Policy_Gradient():\n",
    "    def __init__(self, env):\n",
    "        # init some parameters\n",
    "        self.time_step = 0\n",
    "        self.state_dim = env.observation_space.shape[0]\n",
    "        self.action_dim = env.action_space.n\n",
    "        self.ep_obs, self.ep_as, self.ep_rs = [], [], []\n",
    "        self.create_softmax_network()\n",
    "\n",
    "        # Init session\n",
    "        self.session = tf.InteractiveSession()\n",
    "        self.session.run(tf.global_variables_initializer())\n",
    "\n",
    "    def create_softmax_network(self):\n",
    "        #create input layer\n",
    "        self.state_input = tf.placeholder(\"float\", [None, self.state_dim])\n",
    "        self.tf_acts = tf.placeholder(tf.int32, [None, ], name=\"actions_num\")\n",
    "        self.tf_vt = tf.placeholder(tf.float32, [None, ], name=\"actions_value\")\n",
    "        w_initializer, b_initializer = tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)\n",
    "        h1 = tf.layers.dense(inputs=self.state_input, units=16, activation=tf.nn.relu,kernel_initializer=w_initializer,bias_initializer=b_initializer)\n",
    "        h2 = tf.layers.dense(inputs=h1, units=16, activation=tf.nn.relu,kernel_initializer=w_initializer,bias_initializer=b_initializer)\n",
    "        self.softmax_input = tf.layers.dense(inputs=h2, units=self.action_dim,kernel_initializer=w_initializer,bias_initializer=b_initializer)\n",
    "        #create output layer\n",
    "        #softmax output\n",
    "        self.all_act_prob = tf.nn.softmax(self.softmax_input, name='act_prob')\n",
    "        self.neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.softmax_input,\n",
    "                                                                      labels=self.tf_acts)\n",
    "        self.loss = tf.reduce_mean(self.neg_log_prob * self.tf_vt)  # reward guided loss\n",
    "\n",
    "        self.train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(self.loss)\n",
    "\n",
    "    def choose_action(self, observation):\n",
    "        prob_weights = self.session.run(self.all_act_prob, feed_dict={self.state_input: observation[np.newaxis, :]})\n",
    "        action = np.random.choice(range(prob_weights.shape[1]), p=prob_weights.ravel())  # select action w.r.t the actions prob\n",
    "        return action\n",
    "\n",
    "    def store_transition(self, s, a, r):\n",
    "        self.ep_obs.append(s)\n",
    "        self.ep_as.append(a)\n",
    "        self.ep_rs.append(r)\n",
    "\n",
    "    def learn(self):\n",
    "        #蒙特卡罗策略梯度reinforce\n",
    "        discounted_ep_rs = np.zeros_like(self.ep_rs)\n",
    "        running_add = 0\n",
    "        for t in reversed(range(0, len(self.ep_rs))):\n",
    "            running_add = running_add * GAMMA + self.ep_rs[t]\n",
    "            discounted_ep_rs[t] = running_add\n",
    "\n",
    "        discounted_ep_rs -= np.mean(discounted_ep_rs)\n",
    "        discounted_ep_rs /= np.std(discounted_ep_rs)\n",
    "\n",
    "        # train on episode\n",
    "        self.session.run(self.train_op, feed_dict={\n",
    "             self.state_input: np.vstack(self.ep_obs),\n",
    "             self.tf_acts: np.array(self.ep_as),\n",
    "             self.tf_vt: discounted_ep_rs,\n",
    "        })\n",
    "\n",
    "        self.ep_obs, self.ep_as, self.ep_rs = [], [], []    # empty episode data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the CartPole environment and the training hyperparameters\n",
    "ENV_NAME = 'CartPole-v0'\n",
    "EPISODE = 3000 # Maximum number of training episodes\n",
    "STEP = 3000 # Maximum number of steps per episode\n",
    "TEST = 10 # Number of evaluation episodes run every 100 training episodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-2-657986570255>:22: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dense instead.\n",
      "WARNING:tensorflow:From D:\\AnaconDA\\lib\\site-packages\\tensorflow\\python\\framework\\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "episode:  0 Evaluation Average Reward: 25.9\n",
      "episode:  100 Evaluation Average Reward: 200.0\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-4-dc61a471344b>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     33\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     34\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'__main__'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 35\u001b[1;33m     \u001b[0mmain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-4-dc61a471344b>\u001b[0m in \u001b[0;36mmain\u001b[1;34m()\u001b[0m\n\u001b[0;32m     23\u001b[0m                 \u001b[0mstate\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     24\u001b[0m                 \u001b[1;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mSTEP\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 25\u001b[1;33m                     \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     26\u001b[0m                     \u001b[0maction\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0magent\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mchoose_action\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# direct action for test\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     27\u001b[0m                     \u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mreward\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mdone\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0m_\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0maction\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\AnaconDA\\lib\\site-packages\\gym\\core.py\u001b[0m in \u001b[0;36mrender\u001b[1;34m(self, mode, **kwargs)\u001b[0m\n\u001b[0;32m    247\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    248\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mrender\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'human'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 249\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    250\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    251\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\AnaconDA\\lib\\site-packages\\gym\\envs\\classic_control\\cartpole.py\u001b[0m in \u001b[0;36mrender\u001b[1;34m(self, mode)\u001b[0m\n\u001b[0;32m    186\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpoletrans\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mset_rotation\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    187\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 188\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mviewer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mreturn_rgb_array\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmode\u001b[0m\u001b[1;33m==\u001b[0m\u001b[1;34m'rgb_array'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    189\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    190\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\AnaconDA\\lib\\site-packages\\gym\\envs\\classic_control\\rendering.py\u001b[0m in \u001b[0;36mrender\u001b[1;34m(self, return_rgb_array)\u001b[0m\n\u001b[0;32m    112\u001b[0m             \u001b[0marr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0marr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mheight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbuffer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwidth\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m4\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    113\u001b[0m             \u001b[0marr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0marr\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 114\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    115\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0monetime_geoms\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    116\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0marr\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mreturn_rgb_array\u001b[0m \u001b[1;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0misopen\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\AnaconDA\\lib\\site-packages\\pyglet\\window\\win32\\__init__.py\u001b[0m in \u001b[0;36mflip\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    319\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mflip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    320\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdraw_mouse_cursor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 321\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    322\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    323\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mset_location\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\AnaconDA\\lib\\site-packages\\pyglet\\gl\\win32.py\u001b[0m in \u001b[0;36mflip\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    224\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    225\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mflip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 226\u001b[1;33m         \u001b[0m_gdi32\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSwapBuffers\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mhdc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    227\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    228\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mget_vsync\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "def main():\n",
    "    \"\"\"Train a policy-gradient agent on ENV_NAME and evaluate it periodically.\n",
    "\n",
    "    Runs EPISODE training episodes of at most STEP steps each; the agent is\n",
    "    updated once per episode. Every 100 episodes the policy is rolled out\n",
    "    TEST times with rendering and the average total reward is printed.\n",
    "    The environment is always closed on exit, including on interruption.\n",
    "    \"\"\"\n",
    "    env = gym.make(ENV_NAME)\n",
    "    try:\n",
    "        agent = Policy_Gradient(env)\n",
    "\n",
    "        for episode in range(EPISODE):\n",
    "            state = env.reset()\n",
    "            for _ in range(STEP):\n",
    "                action = agent.choose_action(state)  # sample an action from the policy's probabilities\n",
    "                next_state, reward, done, _ = env.step(action)\n",
    "                agent.store_transition(state, action, reward)  # buffer every step of the episode\n",
    "                state = next_state\n",
    "                if done:\n",
    "                    break\n",
    "\n",
    "            # The policy is updated once per episode from the whole trajectory:\n",
    "            # only the episode's outcome (its returns) drives the update, not\n",
    "            # the quality of individual intermediate steps. Learning here, and\n",
    "            # not only on `done`, also covers an episode cut off by the STEP\n",
    "            # limit, whose transitions would otherwise leak into the next one.\n",
    "            agent.learn()\n",
    "\n",
    "            # Test every 100 episodes\n",
    "            if episode % 100 == 0:\n",
    "                total_reward = 0\n",
    "                for i in range(TEST):\n",
    "                    state = env.reset()\n",
    "                    for j in range(STEP):\n",
    "                        env.render()\n",
    "                        # actions are still sampled from the policy here (not argmax)\n",
    "                        action = agent.choose_action(state)\n",
    "                        state, reward, done, _ = env.step(action)\n",
    "                        total_reward += reward\n",
    "                        if done:\n",
    "                            break\n",
    "                ave_reward = total_reward/TEST\n",
    "                print ('episode: ',episode,'Evaluation Average Reward:',ave_reward)\n",
    "    finally:\n",
    "        # Release the environment (and its render window) even on Ctrl-C.\n",
    "        env.close()\n",
    "\n",
    "# Start training when this cell is executed as the top-level script.\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
